{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Tutorial 2: Working with Approximate Numbers\n",
    "\n",
    "Welcome to TenSEAL's third tutorial, where we will show how to use the library for operations on encrypted real numbers. We will also present another use case for encrypted evaluations over convolutions.\n",
    "\n",
    "This tutorial is inspired by the \"Introduction to CKKS\" talk at [Microsoft Private AI Bootcamp](https://www.microsoft.com/en-us/research/event/private-ai-bootcamp).\n",
    "\n",
    "We recommend checking out the other tutorials first:\n",
    "- ['Tutorial 0 - Getting Started'](https://github.com/OpenMined/TenSEAL/blob/master/tutorials/Tutorial%200%20-%20Getting%20Started.ipynb).\n",
    "- ['Tutorial 1: Training and Evaluation of Logistic Regression on Encrypted Data'](https://github.com/OpenMined/TenSEAL/blob/master/tutorials/Tutorial%201%20-%20Training%20and%20Evaluation%20of%20Logistic%20Regression%20on%20Encrypted%20Data.ipynb).\n",
    "\n",
    "Authors:\n",
    "- Ayoub Benaissa - Twitter: [@y0uben11](https://twitter.com/y0uben11)\n",
    "- Bogdan Cebere - Twitter: [@bcebere](https://twitter.com/bcebere)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Introduction\n",
    "\n",
    "\n",
    "\n",
    "TenSEAL is a library for doing homomorphic encryption operations on tensors. It's built on top of [Microsoft SEAL](https://github.com/Microsoft/SEAL), a C++ library implementing the BFV and CKKS homomorphic encryption schemes.\n",
    "\n",
    "\n",
    "In this tutorial, we will briefly introduce and explain the CKKS scheme, highlighting its advantages. For more in-depth explanations, be sure to check the excellent \"CKKS explained\" series:\n",
    "\n",
    "- ['Part 1, Vanilla Encoding and Decoding'](https://blog.openmined.org/ckks-explained-part-1-simple-encoding-and-decoding/).\n",
    "- ['Part 2, Full Encoding and Decoding'](https://blog.openmined.org/ckks-explained-part-2-ckks-encoding-and-decoding/).\n",
    "- ['Part 3, Encryption and Decryption'](https://blog.openmined.org/ckks-explained-part-3-encryption-and-decryption/).\n",
    "- ['Part 4, Multiplication and Relinearization'](https://blog.openmined.org/ckks-explained-part-4-multiplication-and-relinearization/).\n",
    "- ['Part 5, Rescaling'](https://blog.openmined.org/ckks-explained-part-5-rescaling/).\n",
    "\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Theory: CKKS scheme\n",
    "\n",
    "__Definition__ : Cheon-Kim-Kim-Song(CKKS) is a scheme for Leveled Homomorphic Encryption that supports approximate arithmetics over complex numbers (hence, real numbers).\n",
    "\n",
    " \n",
    " \n",
    "A high-level overview of the CKKS scheme is presented in the following diagram:\n",
    "\n",
    "<img src=\"https://blog.openmined.org/content/images/2020/08/Cryptotree_diagrams-2.svg\" alt=\"ckks-high-level\" width=\"600\"/>\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Theory: CKKS Parameters\n",
    "\n",
    "#### The scaling factor\n",
    "The first step of the CKKS scheme is encoding a vector of real numbers into a plaintext polynomial.\n",
    "\n",
    "\n",
    "The scaling factor defines the encoding precision for the binary representation of the number. Intuitively, we are talking about binary precision as pictured below:\n",
    "\n",
    "\n",
    "\n",
    "<img src=\"assets/floating_point.png\" alt=\"ckks-high-level\" width=\"400\"/>\n",
    "\n",
    "#### The polynomial modulus degree(poly_modulus_degree)\n",
    "\n",
    "The polynomial modulus($N$ in the diagram) directly affects:\n",
    " - The number of coefficients in plaintext polynomials\n",
    " - The size of ciphertext elements\n",
    " - The computational performance of the scheme (bigger is worse)\n",
    " - The security level (bigger is better).\n",
    "\n",
    "In TenSEAL, as in Microsoft SEAL, the degree of the polynomial modulus must be a power of 2 (e.g. $1024$, $2048$, $4096$, $8192$, $16384$, or $32768$).\n",
    "\n",
    "#### The coefficient modulus sizes\n",
    "\n",
    "The last parameter required for the scheme is a list of binary sizes.\n",
    "Using this list, SEAL will generate a list of primes of those binary sizes, called the coefficient modulus($q$ in the diagram).\n",
    "\n",
    "The coefficient modulus directly affects:\n",
    " - The size of ciphertext elements\n",
    " - The length of the list indicates the level of the scheme (or the number of encrypted multiplications supported).\n",
    " - The security level (bigger is worse).\n",
    " \n",
    "In TenSEAL, as in Microsoft SEAL, each of the prime numbers in the coefficient modulus must be at most 60 bits and must be congruent to 1 modulo 2*poly_modulus_degree."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Theory: CKKS Keys\n",
    "\n",
    "#### The secret key\n",
    "The secret key is used for decryption. DO NOT SHARE IT.\n",
    "\n",
    "#### The public encryption key\n",
    "The key is used for encryption in the public key encryption setup.\n",
    "\n",
    "#### The relinearization keys\n",
    "Every new ciphertext has a size of 2, and multiplying ciphertexts of sizes $K$ and $L$ results in a ciphertext of size $K+L-1$. Unfortunately, this growth in size slows down further multiplications and increases noise growth.\n",
    "\n",
    "Relinearization is the operation that reduces the size of ciphertexts back to 2. This operation requires another type of public keys, the relinearization keys created by the secret key owner. \n",
    "\n",
    "The operation is needed for encrypted multiplications. The plain multiplication is fundamentally different from normal multiplication and does not result in ciphertext size growth.\n",
    "\n",
    "#### The Galois Keys(optional)\n",
    "Galois keys are another type of public keys needed to perform encrypted vector rotation operations on batched ciphertexts. \n",
    "\n",
    "One use case for vector rotations is summing the batched vector that is encrypted."
   ],
   "metadata": {}
  },
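  {
   "cell_type": "markdown",
   "source": [
    "To make the rotation use case concrete, here is a plain-Python sketch of the common rotate-and-sum idea: after about `log2(n)` rotations and additions, every slot holds the sum of all `n` slots. The functions `rotate_left` and `rotate_and_sum` are hypothetical stand-ins operating on ordinary lists; with encrypted data each rotation would require the Galois keys, and this is not necessarily how TenSEAL implements its own sum.\n",
    "\n",
    "```python\n",
    "def rotate_left(slots, steps):\n",
    "    # Plain stand-in for an encrypted rotation of the batched vector.\n",
    "    steps %= len(slots)\n",
    "    return slots[steps:] + slots[:steps]\n",
    "\n",
    "\n",
    "def rotate_and_sum(slots):\n",
    "    # Assumes len(slots) is a power of two, as CKKS slot counts are.\n",
    "    acc = list(slots)\n",
    "    step = len(acc) // 2\n",
    "    while step >= 1:\n",
    "        rotated = rotate_left(acc, step)\n",
    "        acc = [a + b for a, b in zip(acc, rotated)]\n",
    "        step //= 2\n",
    "    return acc\n",
    "\n",
    "\n",
    "print(rotate_and_sum([1.0, 2.0, 3.0, 4.0]))\n",
    "```\n",
    "\n",
    "For four slots this takes two rotations, and each slot of the result contains the total of the input."
   ],
   "metadata": {}
  },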
  {
   "cell_type": "markdown",
   "source": [
    "## Theory: CKKS internal operations\n",
    "These operations are automatically executed by TenSEAL, unless the user opts-out.\n",
    "\n",
    "#### Relinearization\n",
    "The operation is executed automatically by TenSEAL after each encrypted multiplication.\n",
    "\n",
    "The operations relinearize a ciphertext, reducing its size down to $2$. If the size of encrypted ciphertext is $K+1$, the given relinearization keys need to have a size of at least $K-1$.\n",
    "\n",
    "#### Rescaling\n",
    "The operation is executed automatically by TenSEAL after each encrypted or plain multiplication.\n",
    "\n",
    "The approximation error exponentially grows with the number of homomorphic multiplications.\n",
    "To overcome this problem, most HE schemes usually use a modulus-switching technique. In the case of CKKS, the modulus-switching procedure is called rescaling. Applying the rescaling algorithm after a homomorphic multiplication, the approximation error grows linearly, not exponentially.\n",
    "\n",
    "Given a ciphertext encrypted modulo $q_1...q_k$, this function switches the modulus down to $q_1...q_{k-1}$ and scales the message down accordingly.\n",
    "\n",
    "This step consumes one prime from the coefficient modulus. And when you consume all of them, you won't be able to perform more multiplications."
   ],
   "metadata": {}
  },
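  {
   "cell_type": "markdown",
   "source": [
    "The bookkeeping behind rescaling can be sketched with simple arithmetic on bit sizes. Multiplying two values encoded at scales $2^a$ and $2^b$ gives a result at scale $2^{a+b}$, and rescaling divides by one prime of roughly $2^p$, bringing the scale back to about $2^{a+b-p}$. The function `track_scale` is a hypothetical helper, and it assumes, following the guidance in Microsoft SEAL's CKKS examples, that only the middle primes of the list are consumed by rescaling. The primes are only close to powers of two, so the real scales are approximate.\n",
    "\n",
    "```python\n",
    "def track_scale(scale_bits, coeff_mod_bit_sizes, multiplications):\n",
    "    # Assumption: the first and last primes are reserved, so only the\n",
    "    # middle primes are available for rescaling.\n",
    "    rescale_primes = list(coeff_mod_bit_sizes[1:-1])\n",
    "    current_bits = scale_bits\n",
    "    history = []\n",
    "    for _ in range(multiplications):\n",
    "        if not rescale_primes:\n",
    "            raise ValueError(\"no primes left to rescale with\")\n",
    "        # Multiply by a fresh operand encoded at scale_bits.\n",
    "        product_bits = current_bits + scale_bits\n",
    "        # Rescale: drop one prime and divide the scale by it.\n",
    "        current_bits = product_bits - rescale_primes.pop()\n",
    "        history.append((product_bits, current_bits))\n",
    "    return history\n",
    "\n",
    "\n",
    "print(track_scale(40, [60, 40, 40, 60], 2))\n",
    "```\n",
    "\n",
    "Each tuple holds the scale in bits right after a multiplication and after the following rescale. Under the stated assumption, asking for a third multiplication with these sizes raises an error because no rescaling prime is left."
   ],
   "metadata": {}
  },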
  {
   "cell_type": "markdown",
   "source": [
    "## Setup\n",
    "\n",
    "All modules are imported here. Make sure everything is installed by running the cell below."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "from random import randint\n",
    "import pickle\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from matplotlib.pyplot import imshow\n",
    "from typing import Dict\n",
    "\n",
    "import tenseal as ts"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## TenSEAL CKKS Context\n",
    "\n",
    "The first step is to create a CKKS TenSEAL context.\n",
    "\n",
    "One potential example is \n",
    "```\n",
    "ctx = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])\n",
    "```\n",
    "which specifies:\n",
    " - scheme type: ts.SCHEME_TYPE.CKKS\n",
    " - poly_modulus_degree: $8192$.\n",
    " - coeff_mod_bit_sizes: The coefficient modulus sizes, here [60, 40, 40, 60]. This means that the coefficient modulus will contain 4 primes of 60 bits, 40 bits, 40 bits, and 60 bits. \n",
    " - global_scale: the scaling factor, here set to $2^{40}$.\n",
    " - optionally, TenSEAL supports switching between the public key and symmetric key encryption. By default, we will use public-key encryption.\n",
    "\n",
    "\n",
    "By default, the relinearization keys are created, with automatic relinearization and rescaling enabled by default.\n",
    "The user can create the Galois keys by calling generate_galois_keys."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def context():\n",
    "    context = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])\n",
    "    context.global_scale = pow(2, 40)\n",
    "    context.generate_galois_keys()\n",
    "    return context\n",
    "\n",
    "context = context()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Plain tensor creation\n",
    "\n",
    "PlainTensor class works as a translation layer from common tensor representations to the encrypted forms offered by TenSEAL. It is the first step required for creating an encrypted tensor using TenSEAL.\n",
    "\n",
    "Observation: This translation is also automatically done by the encrypted tensor constructors, and you can skip it.\n",
    "\n",
    "<img src=\"assets/plaintensor_indepth.png\" align=\"center\" style=\"display: block;  margin: auto;\" />\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "plain1 = ts.plain_tensor([1,2,3,4], [2,2])\n",
    "\n",
    "print(\" First tensor: Shape = {} Data = {}\".format(plain1.shape, plain1.tolist()))\n",
    "\n",
    "plain2 = ts.plain_tensor(np.array([5,6,7,8]).reshape(2,2))\n",
    "print(\" Second tensor: Shape = {} Data = {}\".format(plain2.shape, plain2.tolist()))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Theory: Encrypted tensor creation\n",
    "\n",
    "CKKS requires two operations for encrypting a new message:\n",
    "\n",
    "### CKKS Encoding and Decoding\n",
    "The operation encodes vectors of complex or real numbers into plaintext polynomials to be encrypted and computed using the CKKS scheme.\n",
    "\n",
    "If the polynomial modulus degree is $N$, then the encoding converts vectors of N/2 complex numbers into plaintext elements. Homomorphic operations performed on such encrypted vectors are applied coefficient (slot-)wise, enabling powerful SIMD functionality for computations that are vectorizable. (also known as batching)\n",
    "\n",
    "\n",
    "The following diagram shows the detailed encoding-decoding flow(credits to [Yongsoo Song, Introduction to CKKS, [Microsoft Private AI Bootcamp]](https://www.youtube.com/watch?v=iQlgeL64vfo))\n",
    "\n",
    "<img src=\"assets/ckks_encoding.png\" alt=\"ckks-high-level\" width=\"600\"/>\n",
    "\n",
    "### CKKS Encryption and Decryption\n",
    "This operation converts a plaintext polynomial to a ciphertext.\n",
    "\n",
    "\n",
    "The following diagram shows the detailed encryption-decryption flow(credits to [Yongsoo Song, Introduction to CKKS, [Microsoft Private AI Bootcamp]](https://www.youtube.com/watch?v=iQlgeL64vfo)))\n",
    "\n",
    "<img src=\"assets/ckks_encryption.png\" alt=\"ckks-high-level\" width=\"600\"/>"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Practice: Encrypted tensor creation\n",
    "\n",
    "For creating a new encrypted tensor, TenSEAL executes the encoding and encryption automatically.\n",
    "This applies to both CKKS and BFV schemes.\n",
    "\n",
    "The encrypted tensor encrypts a PlainTensor and stores the ciphertexts and shapes internally.\n",
    "\n",
    "We have a few variants of encrypted tensors:\n",
    " - **BFVVector** - for 1D integer arrays.\n",
    " - **CKKSVector** - for 1D float arrays. This version has a smaller memory footprint, but it is less flexible.\n",
    " - **CKKSTensor** - for N-dimensional float arrays. This version supports tensorial operations on encrypted data, like reshaping or broadcasting.\n",
    " \n",
    " \n",
    "<img src=\"assets/encrypted_tensor_relation.png\" align=\"center\" style=\"display: block;  margin: auto;\" />"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "encrypted_tensor1 = ts.ckks_tensor(context, plain1)\n",
    "encrypted_tensor2 = ts.ckks_tensor(context, plain2)\n",
    "\n",
    "print(\" Shape = {}\".format(encrypted_tensor1.shape))\n",
    "print(\" Encrypted Data = {}.\".format(encrypted_tensor1))\n",
    "\n",
    "\n",
    "encrypted_tensor_from_np = ts.ckks_tensor(context, np.array([5,6,7,8]).reshape([2,2]))\n",
    "print(\" Shape = {}\".format(encrypted_tensor_from_np.shape))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Basic operations\n",
    "\n",
    "The following table enumerates the operations supported by CKKS tensors variants.\n",
    "\n",
    "| Operation                    | Description                                                   |\n",
    "| --- | --- |\n",
    "| negate                       | Negate an encrypted tensor                                    |\n",
    "| square                       | Compute the square of an encrypted tensor                     |\n",
    "| power                        | Compute the power of an encrypted tensor                      |\n",
    "| add                          | Addition between two encrypted tensors                        |\n",
    "| add\\_plain                   | Addition between an encrypted tensor and a plain tensor       |\n",
    "| sub                          | Subtraction between two encrypted tensors                     |\n",
    "| sub\\_plain                   | Subtraction between an encrypted tensor and a plain tensor    |\n",
    "| mul                          | Multiplication between two encrypted tensors                  |\n",
    "| mul\\_plain                   | Multiplication between an encrypted tensor and a plain tensor |\n",
    "| dot                 | Dot product between two encrypted tensors                     |\n",
    "| dot\\_plain          | Dot product between an encrypted tensor and a plain tensor    |\n",
    "| polyval                      | Polynomial evaluation with an encrypted tensor as variable    |\n",
    "| matmul                | Multiplication between an encrypted vector and a plain matrix |\n",
    "| matmul\\_plain           | Encrypted matrix multiplication with plain vector             |\n",
    "\n",
    "\n",
    "The CKKSVector variant contains the following additional operations:\n",
    "\n",
    "\n",
    "| Operation                    | Description                                                   |\n",
    "| --- | --- |\n",
    "| conv2d\\_im2col               | Image Block to Columns                                        |\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def decrypt(enc):\n",
    "    return enc.decrypt().tolist()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Addition of two encrypted tensors."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = encrypted_tensor1 + encrypted_tensor2\n",
    "print(\"Plain equivalent: {} + {}\\nDecrypted result: {}.\".format(plain1.tolist(), plain2.tolist(), decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Subtraction of two encrypted tensors."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = encrypted_tensor1 - encrypted_tensor2\n",
    "print(\"Plain equivalent: {} - {}\\nDecrypted result: {}.\".format(plain1.tolist(), plain2.tolist(), decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Multiplication of two encrypted tensors. \n",
    "\n",
    "The following diagram shows the detailed flow for multiplication and relinearization (credits to [Yongsoo Song, Introduction to CKKS, [Microsoft Private AI Bootcamp]](https://www.youtube.com/watch?v=iQlgeL64vfo)))\n",
    "\n",
    "\n",
    "<img src=\"assets/ckks_mul.png\" alt=\"ckks-high-level\" width=\"600\"/>"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = encrypted_tensor1 * encrypted_tensor2\n",
    "print(\"Plain equivalent: {} * {}\\nDecrypted result: {}.\".format(plain1.tolist(), plain2.tolist(), decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Multiplication with plain tensor"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "plain = ts.plain_tensor([5,6,7,8], [2,2])\n",
    "result = encrypted_tensor1 * plain\n",
    "\n",
    "print(\"Plain equivalent: {} * {}\\nDecrypted result: {}.\".format(plain1.tolist(), plain.tolist(), decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Negation"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = -encrypted_tensor1 \n",
    "\n",
    "print(\"Plain equivalent: -{}\\nDecrypted result: {}.\".format(plain1.tolist(), decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Power"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = encrypted_tensor1 ** 3\n",
    "print(\"Plain equivalent: {} ^ 3\\nDecrypted result: {}.\".format(plain1.tolist(), decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Polynomial evaluation $1 + X^2 + X^3$"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = encrypted_tensor1.polyval([1,0,1,1])\n",
    "\n",
    "print(\"X = {}\".format(plain1.tolist()))\n",
    "print(\"1 + X^2 + X^3 = {}.\".format(decrypt(result)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Sigmoid approximation\n",
    "$\\sigma(x) = 0.5 + 0.197 x - 0.004 x^3$\n",
    "\n",
    "Reference: [\"Logistic regression over encrypted data from fully homomorphic encryption\", Hao Chen et al](https://eprint.iacr.org/2018/462.pdf)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = encrypted_tensor1.polyval([0.5, 0.197, 0, -0.004])\n",
    "\n",
    "\n",
    "print(\"X = {}\".format(plain1.tolist()))\n",
    "print(\"0.5 + 0.197 X - 0.004 x^X = {}.\".format(decrypt(result)))\n"
   ],
   "outputs": [],
   "metadata": {}
  },
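  {
   "cell_type": "markdown",
   "source": [
    "As a plaintext reference for the decrypted values, the degree-3 polynomial above can be compared with the exact sigmoid using only the standard library. The function names `sigmoid` and `sigmoid_poly` are illustrative; the coefficients are the ones from the approximation above.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    # Exact logistic function.\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "def sigmoid_poly(x):\n",
    "    # Degree-3 polynomial approximation: 0.5 + 0.197 x - 0.004 x^3.\n",
    "    return 0.5 + 0.197 * x - 0.004 * x ** 3\n",
    "\n",
    "\n",
    "for x in [1, 2, 3, 4]:\n",
    "    print(f\"x = {x}: sigmoid = {sigmoid(x):.4f}, polynomial = {sigmoid_poly(x):.4f}\")\n",
    "```\n",
    "\n",
    "The two agree closely near zero and drift apart as $|x|$ grows, which is why inputs to such an approximation are usually kept in a small range."
   ],
   "metadata": {}
  },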
  {
   "cell_type": "markdown",
   "source": [
    "# Encrypted inference demo\n",
    "\n",
    "Now that we introduced the CKKS scheme let's see it in action.\n",
    "\n",
    "The next example contains a classification over the MNIST dataset using a single convolution and two fully connected layers with a square activation function.\n",
    "\n",
    "It illustrates one of the prominent use cases for homomorphic encryption, as depicted here.\n",
    "\n",
    "<img src=\"https://blog.openmined.org/content/images/2020/04/OM---CKKS-Graphic-v.01@2x.png\" align=\"center\" style=\"display: block;  margin: auto;\"/>\n",
    "\n",
    "\n",
    "Adapted from https://github.com/youben11/encrypted-evaluation"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Client Helpers"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Create the TenSEAL security context\n",
    "def create_ctx():\n",
    "    \"\"\"Helper for creating the CKKS context.\n",
    "    CKKS params:\n",
    "        - Polynomial degree: 8192.\n",
    "        - Coefficient modulus size: [40, 21, 21, 21, 21, 21, 21, 40].\n",
    "        - Scale: 2 ** 21.\n",
    "        - The setup requires the Galois keys for evaluating the convolutions.\n",
    "    \"\"\"\n",
    "    poly_mod_degree = 8192\n",
    "    coeff_mod_bit_sizes = [40, 21, 21, 21, 21, 21, 21, 40]\n",
    "    ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)\n",
    "    ctx.global_scale = 2 ** 21\n",
    "    ctx.generate_galois_keys()\n",
    "    return ctx\n",
    "\n",
    "# Sample an image\n",
    "def load_input():\n",
    "    transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    )\n",
    "    idx = randint(1, 6)\n",
    "    img_name = \"data/mnist-samples/img_{}.jpg\".format(idx)\n",
    "    print(img_name)\n",
    "    img = Image.open(img_name)\n",
    "    return transform(img).view(28, 28).tolist(), img\n",
    "\n",
    "# Helper for encoding the image\n",
    "def prepare_input(ctx, plain_input):\n",
    "    enc_input, windows_nb = ts.im2col_encoding(ctx, plain_input, 7, 7, 3)\n",
    "    assert windows_nb == 64\n",
    "    return enc_input"
   ],
   "outputs": [],
   "metadata": {}
  },
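  {
   "cell_type": "markdown",
   "source": [
    "The assertion `windows_nb == 64` in `prepare_input` can be checked by hand. Assuming the arguments `7, 7, 3` passed to `im2col_encoding` are the kernel height, kernel width and stride, and that no padding is applied, a 28x28 image yields 8 kernel positions per axis. The helper `conv_windows` below is a hypothetical name used only for this check.\n",
    "\n",
    "```python\n",
    "def conv_windows(height, width, kernel_h, kernel_w, stride):\n",
    "    # Number of kernel positions along each axis for a convolution without padding.\n",
    "    rows = (height - kernel_h) // stride + 1\n",
    "    cols = (width - kernel_w) // stride + 1\n",
    "    return rows * cols\n",
    "\n",
    "\n",
    "print(conv_windows(28, 28, 7, 7, 3))\n",
    "```\n",
    "\n",
    "Here $(28 - 7) / 3 + 1 = 8$ positions per axis give $8 \\times 8 = 64$ windows."
   ],
   "metadata": {}
  },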
  {
   "cell_type": "markdown",
   "source": [
    "## Server Model\n",
    "\n",
    " - We are using a pretrained plain model, stored in \"tutorials/parameters/ConvMNIST-0.1.pickle\"."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Load a pretrained model and adapt the forward call for encrypted input\n",
    "class ConvMNIST():\n",
    "    \"\"\"CNN for classifying MNIST data.\n",
    "    Input should be an encoded 28x28 matrix representing the image.\n",
    "    TenSEAL can be used for encoding `tenseal.im2col_encoding(ctx, input_matrix, 7, 7, 3)`\n",
    "    The input should also be normalized with a mean=0.1307 and an std=0.3081 before encryption.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, parameters: Dict[str, list]):\n",
    "        self.conv1_weight = parameters[\"conv1_weight\"]\n",
    "        self.conv1_bias = parameters[\"conv1_bias\"]\n",
    "        self.fc1_weight = parameters[\"fc1_weight\"]\n",
    "        self.fc1_bias = parameters[\"fc1_bias\"]\n",
    "        self.fc2_weight = parameters[\"fc2_weight\"]\n",
    "        self.fc2_bias = parameters[\"fc2_bias\"]\n",
    "        self.windows_nb = parameters[\"windows_nb\"]\n",
    "\n",
    "    def forward(self, enc_x: ts.CKKSVector) -> ts.CKKSVector:\n",
    "        # conv layer\n",
    "        channels = []\n",
    "        for kernel, bias in zip(self.conv1_weight, self.conv1_bias):\n",
    "            y = enc_x.conv2d_im2col(kernel, self.windows_nb) + bias\n",
    "            channels.append(y)\n",
    "        out = ts.CKKSVector.pack_vectors(channels)\n",
    "        # squaring\n",
    "        out.square_()\n",
    "        # no need to flat\n",
    "        # fc1 layer\n",
    "        out = out.mm_(self.fc1_weight) + self.fc1_bias\n",
    "        # squaring\n",
    "        out.square_()\n",
    "        # output layer\n",
    "        out = out.mm_(self.fc2_weight) + self.fc2_bias\n",
    "        return out\n",
    "\n",
    "    @staticmethod\n",
    "    def prepare_input(context: bytes, ckks_vector: bytes) -> ts.CKKSVector:\n",
    "        try:\n",
    "            ctx = ts.context_from(context)\n",
    "            enc_x = ts.ckks_vector_from(ctx, ckks_vector)\n",
    "        except:\n",
    "            raise DeserializationError(\"cannot deserialize context or ckks_vector\")\n",
    "        try:\n",
    "            _ = ctx.galois_keys()\n",
    "        except:\n",
    "            raise InvalidContext(\"the context doesn't hold galois keys\")\n",
    "\n",
    "        return enc_x\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        return self.forward(*args, **kwargs)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Server helpers"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import pickle\n",
    "import os\n",
    "\n",
    "def load_parameters(file_path: str) -> dict:\n",
    "    try:\n",
    "        parameters = pickle.load(open(file_path, \"rb\"))\n",
    "        print(f\"Model loaded from '{file_path}'\")\n",
    "    except OSError as ose:\n",
    "        print(\"error\", ose)\n",
    "        raise ose\n",
    "    return parameters\n",
    "\n",
    "\n",
    "\n",
    "parameters = load_parameters(\"parameters/ConvMNIST-0.1.pickle\")\n",
    "model = ConvMNIST(parameters)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Client Query\n",
    "The client has to create the CKKS context for the first query.\n",
    "Then, he samples and encrypts a random image from the dataset.\n",
    "\n",
    "The serialized context and encrypted image are sent to the server for evaluation."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# CKKS context generation.\n",
    "context = create_ctx()\n",
    "\n",
    "# Random image sampling\n",
    "image, orig = load_input()\n",
    "\n",
    "# Image encoding\n",
    "encrypted_image = prepare_input(context, image)\n",
    "\n",
    "print(\"Encrypted image \", encrypted_image)\n",
    "print(\"Original image \")\n",
    "imshow(np.asarray(orig))\n",
    "\n",
    "# We prepare the context for the server, by making it public(we drop the secret key)\n",
    "server_context = context.copy()\n",
    "server_context.make_context_public()\n",
    "\n",
    "# Context and ciphertext serialization\n",
    "server_context = server_context.serialize()\n",
    "encrypted_image = encrypted_image.serialize()\n",
    "\n",
    "\n",
    "client_query = {\n",
    "    \"data\" : encrypted_image,\n",
    "    \"context\" : server_context,\n",
    "}"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Server inference\n",
    "\n",
    "The server deserializes the context and ciphertext.\n",
    "It executes the inference, serializes the result and sends it back to the client."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "encrypted_query = model.prepare_input(client_query[\"context\"], client_query[\"data\"])\n",
    "encrypted_result = model(encrypted_query).serialize()\n",
    "\n",
    "server_response = {\n",
    "    \"data\" : encrypted_result\n",
    "}"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Client process response\n",
    "\n",
    "The client deserializes and decrypts the result.\n",
    "Since we cannot run the non-linearity over the CKKSVector, we run the last softmax step on the client side.\n",
    "\n",
    "Finally, we retrieve the final result."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "result = ts.ckks_vector_from(context, server_response[\"data\"]).decrypt()\n",
    "\n",
    "probs = torch.softmax(torch.tensor(result), 0)\n",
    "label_max = torch.argmax(probs)\n",
    "print(\"Maximum probability for label {}\".format(label_max))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star TenSEAL on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star TenSEAL](https://github.com/OpenMined/TenSEAL)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org).\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## References\n",
    "\n",
    "1. Yongsoo Song, Introduction to CKKS, [Private AI Bootcamp](microsoft.com/en-us/research/event/private-ai-bootcamp/#!videos).\n",
    "2. [Microsoft SEAL](https://github.com/microsoft/SEAL).\n",
    "3. Daniel Huynh, [CKKS Explained Series](https://blog.openmined.org/ckks-explained-part-1-simple-encoding-and-decoding/)."
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
